{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f1f7111a",
   "metadata": {},
   "source": [
    "# 1-2, Example Modeling Workflow for Image Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d66df564",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import datetime\n",
    "\n",
    "# Print a separator bar followed by the current timestamp\n",
    "def printbar():\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "\n",
    "# On macOS, running PyTorch and matplotlib together in Jupyter requires this environment variable\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\" \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd73b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install torchvision==0.11.2\n",
    "!pip install torchkeras==3.2.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ecd187d",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "import torch \n",
    "import torchvision \n",
    "import torchkeras \n",
    "print(\"torch.__version__ = \", torch.__version__)\n",
    "print(\"torchvision.__version__ = \", torchvision.__version__) \n",
    "print(\"torchkeras.__version__ = \", torchkeras.__version__) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75f8960e",
   "metadata": {},
   "source": [
    "```\n",
    "torch.__version__ =  1.10.0\n",
    "torchvision.__version__ =  0.11.2\n",
    "torchkeras.__version__ =  3.2.3\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95b1977b",
   "metadata": {},
   "source": [
    "<br>\n",
    "\n",
    "<font color=\"red\">\n",
    " \n",
    "Reply with the keyword **pytorch** to the WeChat official account **算法美食屋** to get the Baidu Cloud download links for this project's source code and datasets.\n",
    "    \n",
    "</font> \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b1bde63",
   "metadata": {},
   "source": [
    "### 1. Prepare the Data"
   ]
  },
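  {
   "cell_type": "markdown",
   "id": "15a18555",
   "metadata": {},
   "source": [
    "The cifar2 dataset is a subset of cifar10 that contains only the first two classes, airplane and automobile.\n",
    "\n",
    "The training set has 5000 images each of airplane and automobile, and the test set has 1000 images of each.\n",
    "\n",
    "The goal of the cifar2 task is to train a model that classifies images as either airplane or automobile.\n",
    "\n",
    "The file structure of the prepared cifar2 dataset is shown below.\n",
    "\n",
    "![](./data/cifar2.jpg)"
   ]
  },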
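  {
   "cell_type": "markdown",
   "id": "16ecef43",
   "metadata": {},
   "source": [
    "There are two common ways to build an image data pipeline in PyTorch.\n",
    "\n",
    "The first is to read the images with datasets.ImageFolder from torchvision and then load them with DataLoader, which batches the samples and can load them in parallel.\n",
    "\n",
    "The second is to subclass torch.utils.data.Dataset to implement custom reading logic and then load the data with DataLoader in the same way.\n",
    "\n",
    "The second approach is the general-purpose way to read a custom dataset; it works for image datasets and text datasets alike.\n",
    "\n",
    "This section uses the first approach.\n"
   ]
  },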
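  {
   "cell_type": "markdown",
   "id": "a9d4e0b1",
   "metadata": {},
   "source": [
    "For comparison, the second approach (subclassing `torch.utils.data.Dataset`) can be sketched as follows. This is a minimal illustration, not code from this project: `TensorPairDataset` is a hypothetical name, and random in-memory tensors stand in for images that a real implementation would read from disk inside `__getitem__`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class TensorPairDataset(Dataset):\n",
    "    # Hypothetical map-style dataset over in-memory (image, label) tensors\n",
    "    def __init__(self, images, labels, transform=None):\n",
    "        assert len(images) == len(labels)\n",
    "        self.images = images\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        img = self.images[idx]\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        # Float tensor of shape [1], the format a binary cross-entropy loss expects\n",
    "        label = self.labels[idx].float().reshape(1)\n",
    "        return img, label\n",
    "\n",
    "# Assumption: 20 random 3x32x32 tensors stand in for real images\n",
    "demo_images = torch.rand(20, 3, 32, 32)\n",
    "demo_labels = torch.randint(0, 2, (20,))\n",
    "ds_demo = TensorPairDataset(demo_images, demo_labels)\n",
    "dl_demo = DataLoader(ds_demo, batch_size=5, shuffle=False)\n",
    "\n",
    "batch_features, batch_labels = next(iter(dl_demo))\n",
    "print(batch_features.shape, batch_labels.shape)\n",
    "```\n",
    "\n",
    "A map-style dataset only needs `__len__` and `__getitem__`; `DataLoader` handles batching and shuffling on top of it."
   ]
  },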
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61535b2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset,DataLoader\n",
    "from torchvision import transforms as T\n",
    "from torchvision import datasets "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a242a652",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform_img = T.Compose(\n",
    "    [T.ToTensor()])\n",
    "\n",
    "def transform_label(x):\n",
    "    return torch.tensor([x]).float()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53f01ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train = datasets.ImageFolder(\"./eat_pytorch_datasets/cifar2/train/\",\n",
    "            transform = transform_img,target_transform = transform_label)\n",
    "ds_val = datasets.ImageFolder(\"./eat_pytorch_datasets/cifar2/test/\",\n",
    "            transform = transform_img,target_transform = transform_label)\n",
    "print(ds_train.class_to_idx)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70f47187",
   "metadata": {},
   "source": [
    "```\n",
    "{'0_airplane': 0, '1_automobile': 1}\n",
    "```"
   ]
  },
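  {
   "cell_type": "markdown",
   "id": "b3e7f2c4",
   "metadata": {},
   "source": [
    "The mapping above follows from the directory names: `ImageFolder` sorts the class subfolders by name and numbers them from 0. The sketch below reproduces that rule with the standard library only. `find_classes_sketch` is a hypothetical helper written for illustration, not the torchvision function itself, and the temporary folders are empty stand-ins for the dataset.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "def find_classes_sketch(root):\n",
    "    # Sort the sub-directory names, then number them from 0\n",
    "    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())\n",
    "    return {name: idx for idx, name in enumerate(classes)}\n",
    "\n",
    "with tempfile.TemporaryDirectory() as root:\n",
    "    # Create the folders in reverse order to show that creation order does not matter\n",
    "    for name in ['1_automobile', '0_airplane']:\n",
    "        os.makedirs(os.path.join(root, name))\n",
    "    mapping = find_classes_sketch(root)\n",
    "\n",
    "print(mapping)\n",
    "```\n",
    "\n",
    "Prefixing folder names with `0_` and `1_` is a simple way to control which class gets which label."
   ]
  },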
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0aec4c82",
   "metadata": {},
   "outputs": [],
   "source": [
    "dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)\n",
    "dl_val = DataLoader(ds_val,batch_size = 50,shuffle = False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de412f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "# Preview a few samples\n",
    "from matplotlib import pyplot as plt \n",
    "\n",
    "plt.figure(figsize=(8,8)) \n",
    "for i in range(9):\n",
    "    img,label = ds_train[i]\n",
    "    img = img.permute(1,2,0)\n",
    "    ax=plt.subplot(3,3,i+1)\n",
    "    ax.imshow(img.numpy())\n",
    "    ax.set_title(\"label = %d\"%label.item())\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([]) \n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "443f6581",
   "metadata": {},
   "source": [
    "![](./data/1-2-查看样本.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "367ccf3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PyTorch's default image layout is Batch, Channel, Height, Width\n",
    "for features,labels in dl_train:\n",
    "    print(features.shape,labels.shape) \n",
    "    break\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dfed26b",
   "metadata": {},
   "source": [
    "```\n",
    "torch.Size([50, 3, 32, 32]) torch.Size([50, 1])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c614279",
   "metadata": {},
   "source": [
    "### 2. Define the Model"
   ]
  },
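  {
   "cell_type": "markdown",
   "id": "d496cb9e",
   "metadata": {},
   "source": [
    "There are three common ways to build a model in PyTorch: stacking layers in order with nn.Sequential, subclassing nn.Module to define a custom model, and subclassing nn.Module while using model containers (nn.Sequential, nn.ModuleList, nn.ModuleDict) to organize the layers.\n",
    "\n",
    "Here we subclass nn.Module to define a custom model."
   ]
  },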
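  {
   "cell_type": "markdown",
   "id": "c5a8d3f6",
   "metadata": {},
   "source": [
    "As a point of comparison, the first option (`nn.Sequential`) might look like the sketch below. The layer sizes are illustrative choices for a 3x32x32 input, and `seq_net` is a hypothetical name used only here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Layers run in the order they are listed; no forward() method is written by hand\n",
    "seq_net = nn.Sequential(\n",
    "    nn.Conv2d(3, 32, kernel_size=3),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    nn.Conv2d(32, 64, kernel_size=5),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    nn.Dropout2d(p=0.1),\n",
    "    nn.AdaptiveMaxPool2d((1, 1)),\n",
    "    nn.Flatten(),\n",
    "    nn.Linear(64, 32),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(32, 1),\n",
    ")\n",
    "\n",
    "# Assumption: a random batch of four 3x32x32 inputs stands in for real images\n",
    "seq_out = seq_net(torch.randn(4, 3, 32, 32))\n",
    "n_params = sum(p.numel() for p in seq_net.parameters())\n",
    "print(seq_out.shape, n_params)\n",
    "```\n",
    "\n",
    "`nn.Sequential` is compact for a straight chain of layers; subclassing `nn.Module` becomes the better fit once the forward pass needs branching or layer reuse."
   ]
  },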
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dde64c23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the effect of AdaptiveMaxPool2d\n",
    "pool = nn.AdaptiveMaxPool2d((1,1))\n",
    "t = torch.randn(10,8,32,32)\n",
    "pool(t).shape "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87369fff",
   "metadata": {},
   "source": [
    "```\n",
    "torch.Size([10, 8, 1, 1])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39d23d45",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=3,out_channels=32,kernel_size = 3)\n",
    "        self.pool = nn.MaxPool2d(kernel_size = 2,stride = 2)\n",
    "        self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5)\n",
    "        self.dropout = nn.Dropout2d(p = 0.1)\n",
    "        self.adaptive_pool = nn.AdaptiveMaxPool2d((1,1))\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear1 = nn.Linear(64,32)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.linear2 = nn.Linear(32,1)\n",
    "        \n",
    "    def forward(self,x):\n",
    "        x = self.conv1(x)\n",
    "        x = self.pool(x)\n",
    "        x = self.conv2(x)\n",
    "        x = self.pool(x)\n",
    "        x = self.dropout(x)\n",
    "        x = self.adaptive_pool(x)\n",
    "        x = self.flatten(x)\n",
    "        x = self.linear1(x)\n",
    "        x = self.relu(x)\n",
    "        x = self.linear2(x)\n",
    "        return x \n",
    "        \n",
    "net = Net()\n",
    "print(net)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0dcc7ce",
   "metadata": {},
   "source": [
    "```\n",
    "Net(\n",
    "  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))\n",
    "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
    "  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n",
    "  (dropout): Dropout2d(p=0.1, inplace=False)\n",
    "  (adaptive_pool): AdaptiveMaxPool2d(output_size=(1, 1))\n",
    "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
    "  (linear1): Linear(in_features=64, out_features=32, bias=True)\n",
    "  (relu): ReLU()\n",
    "  (linear2): Linear(in_features=32, out_features=1, bias=True)\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98f6d94d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchkeras\n",
    "torchkeras.summary(net,input_data = features);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8420cd5c",
   "metadata": {},
   "source": [
    "```\n",
    "--------------------------------------------------------------------------\n",
    "Layer (type)                            Output Shape              Param #\n",
    "==========================================================================\n",
    "Conv2d-1                            [-1, 32, 30, 30]                  896\n",
    "MaxPool2d-2                         [-1, 32, 15, 15]                    0\n",
    "Conv2d-3                            [-1, 64, 11, 11]               51,264\n",
    "MaxPool2d-4                           [-1, 64, 5, 5]                    0\n",
    "Dropout2d-5                           [-1, 64, 5, 5]                    0\n",
    "AdaptiveMaxPool2d-6                   [-1, 64, 1, 1]                    0\n",
    "Flatten-7                                   [-1, 64]                    0\n",
    "Linear-8                                    [-1, 32]                2,080\n",
    "ReLU-9                                      [-1, 32]                    0\n",
    "Linear-10                                    [-1, 1]                   33\n",
    "Net-11                                       [-1, 1]               54,273\n",
    "==========================================================================\n",
    "Total params: 108,546\n",
    "Trainable params: 108,546\n",
    "Non-trainable params: 0\n",
    "--------------------------------------------------------------------------\n",
    "Input size (MB): 0.000069\n",
    "Forward/backward pass size (MB): 0.359634\n",
    "Params size (MB): 0.414070\n",
    "Estimated Total Size (MB): 0.773773\n",
    "--------------------------------------------------------------------------\n",
    "```"
   ]
  },
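  {
   "cell_type": "markdown",
   "id": "d7b1e4a8",
   "metadata": {},
   "source": [
    "The output shapes and per-layer parameter counts in the table can be checked by hand. The sketch below redoes the arithmetic in plain Python; `conv_out`, `conv_params` and `linear_params` are hypothetical helper names. The per-layer counts sum to 54,273, which matches the `Net-11` row; the `Total params: 108,546` line is exactly twice that, which suggests the summary counts the root `Net` row in addition to its child layers.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel, stride=1, padding=0):\n",
    "    # Output length of a convolution or pooling window along one spatial dimension\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "def conv_params(c_in, c_out, k):\n",
    "    # One k x k kernel per (input channel, output channel) pair, plus one bias per output channel\n",
    "    return c_in * c_out * k * k + c_out\n",
    "\n",
    "def linear_params(n_in, n_out):\n",
    "    # Weight matrix plus bias vector\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "sizes = []\n",
    "size = 32\n",
    "for kernel, stride in [(3, 1), (2, 2), (5, 1), (2, 2)]:  # conv1, pool, conv2, pool\n",
    "    size = conv_out(size, kernel, stride)\n",
    "    sizes.append(size)\n",
    "\n",
    "param_counts = {\n",
    "    'conv1': conv_params(3, 32, 3),\n",
    "    'conv2': conv_params(32, 64, 5),\n",
    "    'linear1': linear_params(64, 32),\n",
    "    'linear2': linear_params(32, 1),\n",
    "}\n",
    "total_params = sum(param_counts.values())\n",
    "print(sizes, param_counts, total_params)\n",
    "```"
   ]
  },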
  {
   "cell_type": "markdown",
   "id": "a141ab6f",
   "metadata": {},
   "source": [
    "### 3. Train the Model"
   ]
  },
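  {
   "cell_type": "markdown",
   "id": "b1dbb766",
   "metadata": {},
   "source": [
    "PyTorch usually requires users to write their own training loop, and the code style of that loop varies from person to person.\n",
    "\n",
    "There are three typical styles: a script-style loop, a function-style loop, and a class-style loop.\n",
    "\n",
    "This section presents a fairly general function-style training loop modeled on Keras.\n",
    "\n",
    "The code of this training loop is also the core code of the torchkeras library.\n",
    "\n",
    "torchkeras details:  https://github.com/lyhue1991/torchkeras \n"
   ]
  },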
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5beafb9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,sys,time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import datetime \n",
    "from tqdm import tqdm \n",
    "\n",
    "import torch\n",
    "from torch import nn \n",
    "from copy import deepcopy\n",
    "\n",
    "def printlog(info):\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    print(\"\\n\"+\"==========\"*8 + \"%s\"%nowtime)\n",
    "    print(str(info)+\"\\n\")\n",
    "\n",
    "class StepRunner:\n",
    "    def __init__(self, net, loss_fn,\n",
    "                 stage = \"train\", metrics_dict = None, \n",
    "                 optimizer = None\n",
    "                 ):\n",
    "        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage\n",
    "        self.optimizer = optimizer\n",
    "            \n",
    "    def step(self, features, labels):\n",
    "        #loss\n",
    "        preds = self.net(features)\n",
    "        loss = self.loss_fn(preds,labels)\n",
    "        \n",
    "        #backward()\n",
    "        if self.optimizer is not None and self.stage==\"train\": \n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "            self.optimizer.zero_grad()\n",
    "            \n",
    "        #metrics\n",
    "        step_metrics = {self.stage+\"_\"+name:metric_fn(preds, labels).item() \n",
    "                        for name,metric_fn in self.metrics_dict.items()}\n",
    "        return loss.item(),step_metrics\n",
    "    \n",
    "    def train_step(self,features,labels):\n",
    "        self.net.train() #训练模式, dropout层发生作用\n",
    "        return self.step(features,labels)\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def eval_step(self,features,labels):\n",
    "        self.net.eval() #预测模式, dropout层不发生作用\n",
    "        return self.step(features,labels)\n",
    "    \n",
    "    def __call__(self,features,labels):\n",
    "        if self.stage==\"train\":\n",
    "            return self.train_step(features,labels) \n",
    "        else:\n",
    "            return self.eval_step(features,labels)\n",
    "        \n",
    "class EpochRunner:\n",
    "    def __init__(self,steprunner):\n",
    "        self.steprunner = steprunner\n",
    "        self.stage = steprunner.stage\n",
    "        \n",
    "    def __call__(self,dataloader):\n",
    "        total_loss,step = 0,0\n",
    "        loop = tqdm(enumerate(dataloader), total =len(dataloader))\n",
    "        for i, batch in loop: \n",
    "            loss, step_metrics = self.steprunner(*batch)\n",
    "            step_log = dict({self.stage+\"_loss\":loss},**step_metrics)\n",
    "            total_loss += loss\n",
    "            step+=1\n",
    "            if i!=len(dataloader)-1:\n",
    "                loop.set_postfix(**step_log)\n",
    "            else:\n",
    "                epoch_loss = total_loss/step\n",
    "                epoch_metrics = {self.stage+\"_\"+name:metric_fn.compute().item() \n",
    "                                 for name,metric_fn in self.steprunner.metrics_dict.items()}\n",
    "                epoch_log = dict({self.stage+\"_loss\":epoch_loss},**epoch_metrics)\n",
    "                loop.set_postfix(**epoch_log)\n",
    "\n",
    "                for name,metric_fn in self.steprunner.metrics_dict.items():\n",
    "                    metric_fn.reset()\n",
    "        return epoch_log\n",
    "\n",
    "\n",
    "def train_model(net, optimizer, loss_fn, metrics_dict, \n",
    "                train_data, val_data=None, \n",
    "                epochs=10, ckpt_path='checkpoint.pt',\n",
    "                patience=5, monitor=\"val_loss\", mode=\"min\"):\n",
    "    \n",
    "    history = {}\n",
    "\n",
    "    for epoch in range(1, epochs+1):\n",
    "        printlog(\"Epoch {0} / {1}\".format(epoch, epochs))\n",
    "\n",
    "        # 1，train -------------------------------------------------  \n",
    "        train_step_runner = StepRunner(net = net,stage=\"train\",\n",
    "                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict),\n",
    "                optimizer = optimizer)\n",
    "        train_epoch_runner = EpochRunner(train_step_runner)\n",
    "        train_metrics = train_epoch_runner(train_data)\n",
    "\n",
    "        for name, metric in train_metrics.items():\n",
    "            history[name] = history.get(name, []) + [metric]\n",
    "\n",
    "        # 2，validate -------------------------------------------------\n",
    "        if val_data:\n",
    "            val_step_runner = StepRunner(net = net,stage=\"val\",\n",
    "                loss_fn = loss_fn,metrics_dict=deepcopy(metrics_dict))\n",
    "            val_epoch_runner = EpochRunner(val_step_runner)\n",
    "            with torch.no_grad():\n",
    "                val_metrics = val_epoch_runner(val_data)\n",
    "            val_metrics[\"epoch\"] = epoch\n",
    "            for name, metric in val_metrics.items():\n",
    "                history[name] = history.get(name, []) + [metric]\n",
    "\n",
    "        # 3，early-stopping -------------------------------------------------\n",
    "        arr_scores = history[monitor]\n",
    "        best_score_idx = np.argmax(arr_scores) if mode==\"max\" else np.argmin(arr_scores)\n",
    "        if best_score_idx==len(arr_scores)-1:\n",
    "            torch.save(net.state_dict(),ckpt_path)\n",
    "            print(\"<<<<<< reach best {0} : {1} >>>>>>\".format(monitor,\n",
    "                 arr_scores[best_score_idx]),file=sys.stderr)\n",
    "        if len(arr_scores)-best_score_idx>patience:\n",
    "            print(\"<<<<<< {} without improvement in {} epoch, early stopping >>>>>>\".format(\n",
    "                monitor,patience),file=sys.stderr)\n",
    "            break \n",
    "        net.load_state_dict(torch.load(ckpt_path))\n",
    "\n",
    "    return pd.DataFrame(history)\n"
   ]
  },
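  {
   "cell_type": "markdown",
   "id": "e2c6f9b3",
   "metadata": {},
   "source": [
    "The early-stopping rule inside `train_model` can be looked at in isolation. A checkpoint is saved whenever the newest score is the best so far, and training stops once `len(scores) - best_index` exceeds `patience`, that is, after `patience` epochs in a row without a new best. The sketch below replays that rule on a list of scores; `run_early_stopping` is a hypothetical helper and the accuracy values are made up for illustration.\n",
    "\n",
    "```python\n",
    "def run_early_stopping(scores, patience=5, mode='max'):\n",
    "    # Returns (stop_epoch, best_epoch), both 1-based.\n",
    "    # stop_epoch is None when the rule never triggers.\n",
    "    pick_best = max if mode == 'max' else min\n",
    "    history = []\n",
    "    for epoch, score in enumerate(scores, start=1):\n",
    "        history.append(score)\n",
    "        # index() returns the first occurrence, like np.argmax / np.argmin\n",
    "        best_idx = history.index(pick_best(history))\n",
    "        if len(history) - best_idx > patience:\n",
    "            return epoch, best_idx + 1\n",
    "    return None, history.index(pick_best(history)) + 1\n",
    "\n",
    "# Made-up validation accuracies: the best value appears at epoch 3\n",
    "made_up_val_acc = [0.70, 0.80, 0.85, 0.84, 0.83, 0.82]\n",
    "result = run_early_stopping(made_up_val_acc, patience=2, mode='max')\n",
    "print(result)\n",
    "```\n",
    "\n",
    "With `patience=2` the rule fires at epoch 5, two epochs after the best score at epoch 3, so the last listed epoch would never run."
   ]
  },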
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76a95080",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchmetrics \n",
    "\n",
    "class Accuracy(torchmetrics.Accuracy):\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        \n",
    "    def update(self, preds: torch.Tensor, targets: torch.Tensor):\n",
    "        super().update(torch.sigmoid(preds),targets.long())\n",
    "            \n",
    "    def compute(self):\n",
    "        return super().compute()\n",
    "    \n",
    "    \n",
    "loss_fn = nn.BCEWithLogitsLoss()\n",
    "optimizer= torch.optim.Adam(net.parameters(),lr = 0.01)   \n",
    "metrics_dict = {\"acc\":Accuracy()}\n",
    "\n",
    "dfhistory = train_model(net,\n",
    "    optimizer,\n",
    "    loss_fn,\n",
    "    metrics_dict,\n",
    "    train_data = dl_train,\n",
    "    val_data= dl_val,\n",
    "    epochs=10,\n",
    "    patience=5,\n",
    "    monitor=\"val_acc\", \n",
    "    mode=\"max\")\n"
   ]
  },
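  {
   "cell_type": "markdown",
   "id": "f4a0c7d2",
   "metadata": {},
   "source": [
    "The network outputs raw logits, which is why the cell above pairs it with `nn.BCEWithLogitsLoss` and applies `torch.sigmoid` only inside the accuracy metric. The standard-library sketch below illustrates the underlying formula for a single sample. `bce_naive` and `bce_from_logit` are hypothetical helper names, and the second uses the usual numerically stable rearrangement `max(x, 0) - x*y + log(1 + exp(-|x|))`, a textbook identity rather than a copy of PyTorch's internals.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "def bce_naive(logit, target):\n",
    "    # Binary cross-entropy computed from the probability sigmoid(logit)\n",
    "    p = sigmoid(logit)\n",
    "    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))\n",
    "\n",
    "def bce_from_logit(logit, target):\n",
    "    # Same quantity computed directly from the logit, without forming log(0)\n",
    "    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))\n",
    "\n",
    "samples = [(2.0, 1.0), (-1.5, 0.0), (0.3, 1.0), (0.0, 0.0)]\n",
    "for logit, target in samples:\n",
    "    print(logit, target, bce_naive(logit, target), bce_from_logit(logit, target))\n",
    "\n",
    "# With an extreme logit, sigmoid rounds to exactly 1.0 and the naive form would\n",
    "# hit log(0); the rearranged form stays finite.\n",
    "print(bce_from_logit(1000.0, 0.0))\n",
    "```\n",
    "\n",
    "Keeping the sigmoid out of the model and inside the loss is what makes this stable form available during training."
   ]
  },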
  {
   "cell_type": "markdown",
   "id": "bac6747f",
   "metadata": {},
   "source": [
    "```\n",
    "================================================================================2022-07-10 20:06:16\n",
    "Epoch 1 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.74it/s, train_acc=0.735, train_loss=0.53]\n",
    "100%|██████████| 40/40 [00:01<00:00, 20.07it/s, val_acc=0.827, val_loss=0.383]\n",
    "<<<<<< reach best val_acc : 0.8274999856948853 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:06:35\n",
    "Epoch 2 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:16<00:00, 11.96it/s, train_acc=0.832, train_loss=0.391]\n",
    "100%|██████████| 40/40 [00:02<00:00, 18.13it/s, val_acc=0.854, val_loss=0.317]\n",
    "<<<<<< reach best val_acc : 0.8544999957084656 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:06:54\n",
    "Epoch 3 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.87, train_loss=0.313]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.96it/s, val_acc=0.902, val_loss=0.239]\n",
    "<<<<<< reach best val_acc : 0.9024999737739563 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:07:13\n",
    "Epoch 4 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:16<00:00, 11.88it/s, train_acc=0.889, train_loss=0.265]\n",
    "100%|██████████| 40/40 [00:02<00:00, 18.46it/s, val_acc=0.91, val_loss=0.216]\n",
    "<<<<<< reach best val_acc : 0.9100000262260437 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:07:32\n",
    "Epoch 5 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.71it/s, train_acc=0.902, train_loss=0.239]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.68it/s, val_acc=0.891, val_loss=0.279]\n",
    "\n",
    "================================================================================2022-07-10 20:07:51\n",
    "Epoch 6 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.75it/s, train_acc=0.915, train_loss=0.212]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.52it/s, val_acc=0.908, val_loss=0.222]\n",
    "\n",
    "================================================================================2022-07-10 20:08:10\n",
    "Epoch 7 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:16<00:00, 11.79it/s, train_acc=0.921, train_loss=0.196]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.26it/s, val_acc=0.929, val_loss=0.187]\n",
    "<<<<<< reach best val_acc : 0.9294999837875366 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:08:29\n",
    "Epoch 8 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.59it/s, train_acc=0.931, train_loss=0.175]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.938, val_loss=0.187]\n",
    "<<<<<< reach best val_acc : 0.9375 >>>>>>\n",
    "\n",
    "================================================================================2022-07-10 20:08:49\n",
    "Epoch 9 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:17<00:00, 11.68it/s, train_acc=0.929, train_loss=0.178]\n",
    "100%|██████████| 40/40 [00:02<00:00, 19.90it/s, val_acc=0.937, val_loss=0.181]\n",
    "\n",
    "================================================================================2022-07-10 20:09:08\n",
    "Epoch 10 / 10\n",
    "\n",
    "100%|██████████| 200/200 [00:16<00:00, 11.84it/s, train_acc=0.937, train_loss=0.16] \n",
    "100%|██████████| 40/40 [00:02<00:00, 19.91it/s, val_acc=0.937, val_loss=0.167]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2534de95",
   "metadata": {},
   "source": [
    "### 4. Evaluate the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ead5148",
   "metadata": {},
   "outputs": [],
   "source": [
    "dfhistory "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b0e1fde",
   "metadata": {},
   "source": [
    "```\n",
    "train_loss\ttrain_acc\tval_loss\tval_acc\tepoch\n",
    "0\t0.761911\t0.6896\t0.503468\t0.765\t1\n",
    "1\t0.500893\t0.7627\t0.403210\t0.830\t2\n",
    "2\t0.417750\t0.8128\t0.328020\t0.870\t3\n",
    "3\t0.366155\t0.8444\t0.370906\t0.814\t4\n",
    "4\t0.364717\t0.8428\t0.290701\t0.876\t5\n",
    "5\t0.610342\t0.6406\t0.693153\t0.500\t6\n",
    "6\t0.693610\t0.4976\t0.693386\t0.500\t7\n",
    "7\t0.693578\t0.5046\t0.693815\t0.500\t8\n",
    "8\t0.693735\t0.4988\t0.693718\t0.500\t9\n",
    "9\t0.693681\t0.4960\t0.694350\t0.500\t10\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17bb6de",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_metric(dfhistory, metric):\n",
    "    train_metrics = dfhistory[\"train_\"+metric]\n",
    "    val_metrics = dfhistory['val_'+metric]\n",
    "    epochs = range(1, len(train_metrics) + 1)\n",
    "    plt.plot(epochs, train_metrics, 'bo--')\n",
    "    plt.plot(epochs, val_metrics, 'ro-')\n",
    "    plt.title('Training and validation '+ metric)\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.ylabel(metric)\n",
    "    plt.legend([\"train_\"+metric, 'val_'+metric])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9350e7e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"loss\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb3f085a",
   "metadata": {},
   "source": [
    "![](./data/1-2-loss曲线.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d790db14",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_metric(dfhistory,\"acc\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "022747c2",
   "metadata": {},
   "source": [
    "![](./data/1-2-auc曲线.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "997c21b1",
   "metadata": {},
   "source": [
    "### 5. Use the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49191373",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(net,dl):\n",
    "    net.eval()  # disable dropout for inference\n",
    "    with torch.no_grad():\n",
    "        # Call the module directly (not .forward) and turn the logits into probabilities\n",
    "        result = torch.sigmoid(torch.cat([net(t[0]) for t in dl]))\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2df2bf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted probabilities\n",
    "y_pred_probs = predict(net,dl_val)\n",
    "y_pred_probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11793bc5",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[3.6409e-03],\n",
    "        [3.1401e-05],\n",
    "        [1.4732e-02],\n",
    "        ...,\n",
    "        [9.6308e-01],\n",
    "        [9.9835e-01],\n",
    "        [7.8825e-01]])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60b72dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted classes\n",
    "y_pred = torch.where(y_pred_probs>0.5,\n",
    "        torch.ones_like(y_pred_probs),torch.zeros_like(y_pred_probs))\n",
    "y_pred"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94c98285",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[0.],\n",
    "        [0.],\n",
    "        [0.],\n",
    "        ...,\n",
    "        [1.],\n",
    "        [1.],\n",
    "        [1.]])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d6f7881",
   "metadata": {},
   "source": [
    "### 6. Save the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63c666ad",
   "metadata": {},
   "source": [
    "Saving the parameters (the `state_dict`) is the recommended way to save a PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878b235b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(net.state_dict().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62f87cae",
   "metadata": {},
   "source": [
    "```\n",
    "odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias'])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d128841f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model parameters, then load them into a fresh Net instance\n",
    "\n",
    "torch.save(net.state_dict(), \"./data/net_parameter.pt\")\n",
    "\n",
    "net_clone = Net()\n",
    "net_clone.load_state_dict(torch.load(\"./data/net_parameter.pt\"))\n",
    "\n",
    "predict(net_clone,dl_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "912e01c2",
   "metadata": {},
   "source": [
    "```\n",
    "tensor([[3.6409e-03],\n",
    "        [3.1401e-05],\n",
    "        [1.4732e-02],\n",
    "        ...,\n",
    "        [9.6308e-01],\n",
    "        [9.9835e-01],\n",
    "        [7.8825e-01]])\n",
    "```"
   ]
  },
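  {
   "cell_type": "markdown",
   "id": "c6f1b8ee",
   "metadata": {},
   "source": [
    "**If this book has helped you and you would like to encourage the author, please give this project a star ⭐️ and share it with your friends 😊!** \n",
    "\n",
    "If you would like to discuss any part of this book further with the author, feel free to leave a message on the WeChat official account \"算法美食屋\". The author's time and energy are limited, so replies will be given as circumstances allow.\n",
    "\n",
    "You can also reply with the keyword **加群** (join group) to the official account to join the reader discussion group.\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)\n"
   ]
  }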
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,md",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
